import torch as T
import numpy as np
from matplotlib import pyplot as pt
import models
import util
# select compute device; fall back to CPU when the machine does not have a
# fourth GPU (the original hard-coded cuda:3 and crashed everywhere else)
device = T.device('cuda:3' if T.cuda.device_count() > 3 else 'cpu')
# load data and visualize the image
# img: (4, H, W) image (4 channels, first 3 shown as RGB below);
# lc: high-res land-cover labels; nlcd: low-res NLCD labels
# NOTE(review): shapes assumed from usage below -- confirm in util.load_data
img, lc, nlcd = util.load_data()
# one-hot encode each label map along a new leading class axis
lc_oh = np.array([lc == i for i in range(4)])
nlcd_oh = np.array([nlcd == util.nlcd_cl[i] for i in range(22)])
pt.figure(figsize=(12, 4))
pt.subplot(131); pt.imshow(img[:3].T)
pt.subplot(132); pt.imshow(util.vis_lc(lc_oh).T)
pt.subplot(133); pt.imshow(util.vis_nlcd(nlcd_oh).T)
pt.show()
# initialize model
epitome_size = 299
ep = models.EpitomeModel(epitome_size, 4).to(device)
# train the model (best run on GPU)
# see figure in SI for outputs
n_batches = 10000
batch_size = 256
show_interval = 100
diversify = False  # tiny image, no need to
mask_threshold = 1e-8  # fix: was float('1e-8') -- needless string parse of a literal
reset_threshold = 0.95
optimizer = T.optim.Adam(ep.parameters(), lr=0.003)
# per-position visit counter driving the diversity mask during training;
# allocated directly on the device (avoids the extra zeros().to(device) copy)
counter = T.zeros((ep.size, ep.size), device=device)
# main training loop: fit the epitome on random multi-scale image patches
for it in range(n_batches):
    w = np.random.randint(10,16)*2+1 # odd number 21 to 31
    #construct the batches: batch_size random w x w crops of the image
    batch = np.zeros((batch_size, 4, w, w))
    for b in range(batch_size):
        x = np.random.randint(img.shape[1]-w+1)
        y = np.random.randint(img.shape[2]-w+1)
        batch[b] = img[:,x:x+w,y:y+w]
    x = T.from_numpy(batch).to(device, T.float)
    optimizer.zero_grad()
    # compute p(x|s)p(s) and smooth (divide by relative patch area so
    # different patch sizes contribute on a comparable scale)
    e = ep(x) / (w/11)**2
    # extract worst-modeled quarter of data
    if diversify:
        indices = e.logsumexp((0,2,3)).topk(batch_size//4, largest=False, sorted=False).indices
        e = e[:,indices]
    # increment counters and compute mask: posterior is the mean (over
    # layers and batch) responsibility of each epitome position
    posterior = e.view(-1, ep.size*ep.size).softmax(1).view(-1, ep.size, ep.size).mean(0)
    with T.no_grad(): counter += posterior
    # positions already visited get masked out of the likelihood below,
    # pushing later batches toward still-unused epitome area
    mask = (counter > mask_threshold).float()
    # reset counters if threshold reached (i.e. most positions masked)
    if (mask.mean() > reset_threshold):
        counter[:] = 0.
        mask[:] = 0.
    # compute log likelihood of data over unmasked positions (+const);
    # the -10000 term effectively removes masked positions from the logsumexp
    loss = -T.logsumexp(e - 10000*mask, (0,2,3))
    loss.mean().backward()
    optimizer.step()
    # clamp parameters to keep the model numerically well-behaved
    with T.no_grad():
        ep.ivar[:].clamp_(min=1, max=10**2)  # inverse variances kept in [1, 100]
        ep.prior[:] -= ep.prior.mean()       # keep priors zero-centered
        ep.prior[:].clamp_(min=-4., max=4.)
    # show the means periodically (first 3 channels rendered as RGB)
    if it % show_interval == 0:
        pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)
        pt.show()
# persist the trained model, then reload it (the map_location dict remaps a
# checkpoint saved on cuda:2 onto cuda:3; storages on other devices load as-is)
T.save(ep.state_dict(), 'epitome')
ep.load_state_dict(T.load('epitome', map_location={'cuda:2':'cuda:3'}))
# padding budget for the embedded label maps built below
max_patch_size = 31
hw = max_patch_size//2
def label_embed(labels, vis_fn, show=False):
    """Embed label maps into epitome coordinates by posterior sampling.

    Repeatedly draws random image patches, computes the epitome posterior
    for each patch, samples epitome positions from that posterior, and
    accumulates the patch-center labels at the sampled positions.

    Args:
        labels: one-hot label stack, shape (n_labels, H, W) matching img's
            spatial dims (e.g. lc_oh or nlcd_oh).
        vis_fn: renders an (n_labels, h, w) map to an RGB image; used only
            for the optional progress plots.
        show: if True, plot the epitome means and the growing label map
            every 10 batches.

    Returns:
        np.ndarray of shape (n_labels, ep.layers, ep.size + max_patch_size,
        ep.size + max_patch_size): soft label counts over the padded
        epitome, with toroidal wrap-around folded into the borders.

    Reads module-level globals: img, ep, device, batch_size,
    max_patch_size, hw.

    Fix vs. original: the original called T.set_grad_enabled(False) and
    never restored it, silently disabling autograd for all code that ran
    after this function; a scoped T.no_grad() context is used instead.
    """
    n_batches = 51
    n_samples = 64
    # small pseudo-count so never-visited positions are not exactly zero
    ep_map = np.zeros((labels.shape[0], ep.layers,
                       ep.size + max_patch_size, ep.size + max_patch_size)) + 0.00001
    with T.no_grad():  # sampling only -- no gradients needed below
        for it in range(n_batches):
            w = np.random.randint(10, 16)*2 + 1  # odd patch size, 21 to 31
            ew = 11  # size of the patch center from which labels are embedded
            batch = np.zeros((batch_size, 4, w, w))
            lc_batch = np.zeros((batch_size, labels.shape[0], w, w))
            for b in range(batch_size):
                x = np.random.randint(img.shape[1]-w+1)
                y = np.random.randint(img.shape[2]-w+1)
                batch[b] = img[:, x:x+w, y:y+w]
                lc_batch[b] = labels[:, x:x+w, y:y+w]
            x = T.from_numpy(batch).to(device, T.float)
            # per-position log p(x|s)p(s), smoothed by relative patch area
            e = ep(x) / (w/11)**2
            # take the first few batches at higher temperature to fill in
            # the gaps (for a prettier picture)
            temp = max(3-it, 1)
            logits = e.transpose(0, 1).reshape(batch_size, -1) / temp
            dist = T.distributions.Categorical(logits=logits.cpu())
            d = (w-ew)//2                    # margin from patch border to its center
            shift = (max_patch_size-ew)//2   # offset into the padded ep_map
            z = dist.sample([n_samples])     # (n_samples, batch_size) flat indices
            # decode flat index into (layer, x, y) epitome coordinates
            layers = z // (ep.size**2)
            cs = z % (ep.size**2)
            xs, ys = cs//ep.size, cs % ep.size
            for s in range(n_samples):
                for j in range(batch_size):
                    layer, x, y = (a[s, j] for a in (layers, xs, ys))
                    ep_map[:, layer, x+shift:x+shift+ew, y+shift:y+shift+ew] += \
                        lc_batch[j, :, d:w-d, d:w-d]
            if it % 10 == 0 and show:
                pt.subplot(121)
                pt.imshow(ep.mean.detach().cpu().numpy()[0, :3].T)
                pt.subplot(122)
                pt.imshow(vis_fn(ep_map[:, 0, hw:hw+ep.size, hw:hw+ep.size]).T)
                pt.show()
    # wrap around: the epitome is treated as toroidal, so fold the padded
    # borders onto each other.
    # NOTE(review): the [-max_patch_size-1:-1] slices are shifted one pixel
    # from [:max_patch_size]; looks like an off-by-one -- confirm intent
    # before changing (kept as in the original).
    ep_map[..., :max_patch_size, :] += ep_map[..., -max_patch_size-1:-1, :]
    ep_map[..., -max_patch_size-1:-1, :] = ep_map[..., :max_patch_size, :]
    ep_map[..., :, :max_patch_size] += ep_map[..., :, -max_patch_size-1:-1]
    ep_map[..., :, -max_patch_size-1:-1] = ep_map[..., :, :max_patch_size]
    return ep_map
def superres(lrmap, max_iter=20):
    """Super-resolve a coarse (NLCD) epitome label map to 4 fine classes via EM.

    Args:
        lrmap: np.ndarray (n_coarse_classes, layers, X, Y) of soft label
            counts over the epitome (output of label_embed on nlcd_oh).
        max_iter: number of EM iterations.

    Returns:
        np.ndarray (layers, X, Y, 4): p(l|s), the fine-label distribution
        at each epitome position.

    Reads module-level globals: util (util.nlcd_mu -- assumed
    (n_coarse_classes, 4) label co-occurrence table, confirm in util)
    and device.

    Fix vs. original: dropped an unused `q = T.empty(...)` allocation that
    was overwritten on the first E step anyway.
    """
    eps = 1e-11  # numerical floor so normalizations never divide by zero
    # use the full p(l,c)
    nlcd_mu = util.nlcd_mu
    p_l_c = T.from_numpy(nlcd_mu).float().to(device)
    # init the p(s|c): renormalize priors over classes, then over positions
    p_s_c = (T.from_numpy(lrmap).float() + eps).to(device)
    p_s_c /= p_s_c.sum(0)
    p_s_c /= p_s_c.sum((1, 2, 3)).view(-1, 1, 1, 1)
    # init p(l|s) near-uniform (rand in [10,11) then normalized over labels)
    p_l_s = (T.rand(p_s_c.shape[1:] + (4,)) + 10).to(device)
    p_l_s /= p_l_s.sum(3, keepdim=True)
    for it in range(max_iter):
        # E step: q(s|l,c) proportional to p(l|s) p(s|c), normalized over
        # positions (layer, x, y) for each (c, l) pair
        q = T.einsum('exyl,cexy->exycl', p_l_s, p_s_c) + eps
        q /= q.sum((0, 1, 2))
        # M step: re-estimate p(l|s), weighting positions by p(l,c)
        p_l_s = T.einsum('exycl,cl->exyl', q, p_l_c) + eps
        p_l_s /= p_l_s.sum(3, keepdim=True)
    return p_l_s.cpu().numpy()
# build embedding map using high-res labels (lc)
ep_map_hr = label_embed(lc_oh, util.vis_lc)
pt.subplot(121)
pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)
pt.subplot(122)
# show only the un-padded interior of the embedded label map
pt.imshow(util.vis_lc(ep_map_hr[:,0,hw:hw+ep.size,hw:hw+ep.size]).T)
pt.show()
# build embedding map using low-res labels (nlcd)
ep_map_lr = label_embed(nlcd_oh, util.vis_nlcd)
pt.subplot(121)
pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)
pt.subplot(122)
pt.imshow(util.vis_nlcd(ep_map_lr[:,0,hw:hw+ep.size,hw:hw+ep.size]).T)
pt.show()
# super-resolve it to a high-res map; superres returns (layer, x, y, label)
# and .T.swapaxes(1,3) reorders that to (label, layer, x, y) to match the
# ep_map layout expected by segment/vis_lc
ep_map_sr = superres(ep_map_lr).T.swapaxes(1,3)
pt.subplot(121)
pt.imshow(ep.mean.detach().cpu().numpy()[0,:3].T)
pt.subplot(122)
pt.imshow(util.vis_lc(ep_map_sr[:,0,hw:hw+ep.size,hw:hw+ep.size]).T)
pt.show()
def segment(ep_map, vis_fn=util.vis_lc):
    """Segment the image by sampling epitome positions and copying labels back.

    For random image patches, samples epitome positions from the patch
    posterior and accumulates the corresponding ep_map label values back
    into image coordinates; the running average is the segmentation.

    Args:
        ep_map: np.ndarray (n_labels, layers, X, Y) label (or channel) map
            over the padded epitome, e.g. from label_embed/superres.
        vis_fn: renders an (n_labels, h, w) map to RGB for the progress plot.

    Returns:
        np.ndarray (n_labels, H, W): per-pixel averaged label values.

    Reads module-level globals: img, ep, device, batch_size,
    max_patch_size, lc_oh.

    Fixes vs. original: (1) label count generalized from a hard-coded 4 to
    ep_map.shape[0] (identical for all existing call sites); (2) explicit
    T.no_grad() -- the original only ran grad-free because label_embed had
    leaked a global grad-disable, so calling segment on its own would build
    autograd graphs for no reason.
    """
    n_batches = 51
    n_samples = 64
    reconstruction = np.zeros((ep_map.shape[0],) + img.shape[1:])
    counts = np.zeros(img.shape[1:]) + 0.000001  # avoids division by zero for uncovered pixels
    with T.no_grad():  # pure inference/sampling below
        for it in range(n_batches):
            # w = np.random.randint(10,16)*2+1  # size of the patch to compute posterior for
            w = 11
            # making w smaller will make small features more likely to appear in reconstruction
            ew = 11  # size of the center piece of the patch to copy labels from ep_map
            batch = np.zeros((batch_size, 4, w, w))
            coords = []
            for b in range(batch_size):
                x = np.random.randint(img.shape[1]-w+1)
                y = np.random.randint(img.shape[2]-w+1)
                coords.append((x, y))
                batch[b] = img[:, x:x+w, y:y+w]
            x = T.from_numpy(batch).to(device, T.float)
            e = ep(x) / (w/11)**2
            logits = e.transpose(0, 1).reshape(batch_size, -1)
            dist = T.distributions.Categorical(logits=logits.cpu())
            d = (w-ew)//2                    # zero here since w == ew
            shift = (max_patch_size-ew)//2   # offset into the padded ep_map
            z = dist.sample([n_samples])     # (n_samples, batch_size) flat indices
            # decode flat index into (layer, x, y) epitome coordinates
            layers = z // (ep.size**2)
            cs = z % (ep.size**2)
            xs, ys = cs//ep.size, cs % ep.size
            for s in range(n_samples):
                for j in range(batch_size):
                    layer, x, y = (a[s, j] for a in (layers, xs, ys))
                    cx, cy = coords[j]
                    reconstruction[:, cx+d:cx+d+ew, cy+d:cy+d+ew] += \
                        ep_map[:, layer, x+shift:x+shift+ew, y+shift:y+shift+ew]
                    counts[cx+d:cx+d+ew, cy+d:cy+d+ew] += 1
            if it % 10 == 0:
                pt.figure(figsize=(12, 4))
                pt.subplot(131); pt.title('image')
                pt.imshow(img[:3].T)
                pt.subplot(132); pt.title('prediction')
                pt.imshow(vis_fn(reconstruction/counts).T)
                pt.subplot(133); pt.title('gt')
                pt.imshow(util.vis_lc(lc_oh).T)
                pt.show()
    return reconstruction/counts
# segment using the hr-derived epitome embedding
rec_hr = segment(ep_map_hr)
# segment using the nlcd+sr-derived epitome embedding
rec_sr = segment(ep_map_sr)
# build a map of the epitome means themselves (the 4 image channels),
# wrap-padded so its layout resembles the label ep_maps
ep_map_rgbi = ep.mean[0].detach().cpu().numpy()
s = max_patch_size // 2
ep_map_rgbi = np.concatenate( [ ep_map_rgbi[:,-s:,:], ep_map_rgbi, ep_map_rgbi[:,:s,:] ], 1 )
ep_map_rgbi = np.concatenate( [ ep_map_rgbi[:,:,-s:], ep_map_rgbi, ep_map_rgbi[:,:,:s] ], 2 )
# NOTE(review): this pad gives ep.size + 2*(max_patch_size//2) = size+30
# columns, one less than the size+31 of label_embed's maps; segment's
# indexing stays in bounds either way -- confirm the asymmetry is intended.
# "segmenting" with the means map reconstructs the image itself
rec_img = segment(ep_map_rgbi[:,None], vis_fn=lambda x:x[:3])
pt.figure(figsize=(15,5))
pt.subplot(131)
pt.imshow(img[:3].T)
pt.subplot(132)
pt.imshow(rec_img[:3].T)
pt.subplot(133)
# squared reconstruction error, masked to pixels the sampler actually covered
pt.imshow((((img-rec_img)**2).sum(0) * (rec_img.sum(0)>0)).T, cmap='Reds', interpolation='none')